{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gumbel Softmax / Concrete VAE with BayesFlow\n",
    "\n",
    "Implements a categorical VAE using the technique introduced in [The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al. 2016)](https://arxiv.org/abs/1611.00712) and [Categorical Reparameterization with Gumbel-Softmax (Jang et al. 2016)](https://arxiv.org/abs/1611.01144). The VAE architecture shown here are a bit different than the models presented in the papers, this one has 1 stochastic 20x10-ary layer with 2-layer deterministic encoder/decoders and a fixed prior.\n",
    "\n",
    "17 Feb 2017"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "slim=tf.contrib.slim\n",
    "Bernoulli = tf.contrib.distributions.Bernoulli\n",
    "OneHotCategorical = tf.contrib.distributions.OneHotCategorical\n",
    "RelaxedOneHotCategorical = tf.contrib.distributions.RelaxedOneHotCategorical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# black-on-white MNIST (harder to learn than white-on-black MNIST)\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size=100\n",
    "tau0=1.0 # initial temperature\n",
    "K=10 # number of classes\n",
    "N=200//K # number of categorical distributions\n",
    "straight_through=False # if True, use Straight-through Gumbel-Softmax\n",
    "kl_type='relaxed' # choose between ('relaxed', 'categorical')\n",
    "learn_temp=False "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "x=tf.placeholder(tf.float32, shape=(batch_size,784), name='x')\n",
    "net = tf.cast(tf.random_uniform(tf.shape(x)) < x, x.dtype) # dynamic binarization\n",
    "net = slim.stack(net,slim.fully_connected,[512,256])\n",
    "logits_y = tf.reshape(slim.fully_connected(net,K*N,activation_fn=None),[-1,N,K])\n",
    "tau = tf.Variable(tau0,name=\"temperature\",trainable=learn_temp)\n",
    "q_y = RelaxedOneHotCategorical(tau,logits_y)\n",
    "y = q_y.sample()\n",
    "if straight_through:\n",
    "  y_hard = tf.cast(tf.one_hot(tf.argmax(y,-1),K), y.dtype)\n",
    "  y = tf.stop_gradient(y_hard - y) + y\n",
    "net = slim.flatten(y)\n",
    "net = slim.stack(net,slim.fully_connected,[256,512])\n",
    "logits_x = slim.fully_connected(net,784,activation_fn=None)\n",
    "p_x = Bernoulli(logits=logits_x)\n",
    "x_mean = p_x.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "recons = tf.reduce_sum(p_x.log_prob(x),1)\n",
    "logits_py = tf.ones_like(logits_y) * 1./K\n",
    "\n",
    "if kl_type=='categorical' or straight_through:\n",
    "  # Analytical KL with Categorical prior\n",
    "  p_cat_y = OneHotCategorical(logits=logits_py)\n",
    "  q_cat_y = OneHotCategorical(logits=logits_y)\n",
    "  KL_qp = tf.contrib.distributions.kl(q_cat_y, p_cat_y)\n",
    "else:\n",
    "  # Monte Carlo KL with Relaxed prior\n",
    "  p_y = RelaxedOneHotCategorical(tau,logits=logits_py)\n",
    "  KL_qp = q_y.log_prob(y) - p_y.log_prob(y)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "KL = tf.reduce_sum(KL_qp,1)\n",
    "mean_recons = tf.reduce_mean(recons)\n",
    "mean_KL = tf.reduce_mean(KL)\n",
    "loss = -tf.reduce_mean(recons-KL)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_op=tf.train.AdamOptimizer(learning_rate=3e-4).minimize(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data = []\n",
    "with tf.train.MonitoredSession() as sess:\n",
    "  for i in range(1,50000):\n",
    "    batch = mnist.train.next_batch(batch_size)\n",
    "    res = sess.run([train_op, loss, tau, mean_recons, mean_KL], {x : batch[0]})\n",
    "    if i % 100 == 1:\n",
    "      data.append([i] + res[1:])\n",
    "    if i % 1000 == 1:\n",
    "      print('Step %d, Loss: %0.3f' % (i,res[1]))\n",
    "  # end training - do an eval\n",
    "  batch = mnist.test.next_batch(batch_size)\n",
    "  np_x = sess.run(x_mean, {x : batch[0]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data = np.array(data).T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f,axarr=plt.subplots(1,4,figsize=(18,6))\n",
    "axarr[0].plot(data[0],data[1])\n",
    "axarr[0].set_title('Loss')\n",
    "\n",
    "axarr[1].plot(data[0],data[2])\n",
    "axarr[1].set_title('Temperature')\n",
    "\n",
    "axarr[2].plot(data[0],data[3])\n",
    "axarr[2].set_title('Recons')\n",
    "\n",
    "axarr[3].plot(data[0],data[4])\n",
    "axarr[3].set_title('KL')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "tmp = np.reshape(np_x,(-1,280,28)) # (10,280,28)\n",
    "img = np.hstack([tmp[i] for i in range(10)])\n",
    "plt.imshow(img)\n",
    "plt.grid('off')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
